{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "from rdkit import Chem\n",
    "import os\n",
    "from typing import Optional\n",
    "from rdkit import Chem"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate docking grid file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculated_docking_grid_sdf(work_path, json_path, pdbid, ligid, add_size=10):\n",
    "    input_path = os.path.join(work_path, f'{pdbid}_{ligid}.sdf')\n",
    "    os.makedirs(json_path, exist_ok=True)\n",
    "    output_grid = os.path.join(json_path, pdbid + '.json')\n",
    "    add_size = add_size\n",
    "    mol = Chem.MolFromMolFile(str(input_path), sanitize=False)\n",
    "    coords = mol.GetConformer(0).GetPositions().astype(np.float32)\n",
    "    min_xyz = [min(coord[i] for coord in coords) for i in range(3)]\n",
    "    max_xyz = [max(coord[i] for coord in coords) for i in range(3)]\n",
    "    center = np.mean(coords, axis=0)\n",
    "    size = [abs(max_xyz[i] - min_xyz[i]) for i in range(3)]\n",
    "    center_x, center_y, center_z = center\n",
    "    size_x, size_y, size_z = size\n",
    "    size_x = size_x + add_size\n",
    "    size_y = size_y + add_size\n",
    "    size_z = size_z + add_size\n",
    "    grid_info = {\n",
    "        \"center_x\": float(center_x),\n",
    "        \"center_y\": float(center_y),\n",
    "        \"center_z\": float(center_z),\n",
    "        \"size_x\": float(size_x),\n",
    "        \"size_y\": float(size_y),\n",
    "        \"size_z\": float(size_z)\n",
    "    }\n",
    "    with open(output_grid, 'w') as f:\n",
    "        json.dump(grid_info, f, indent=4)\n",
    "    print(pdbid)\n",
    "    print('Center: ({:.6f}, {:.6f}, {:.6f})'.format(center_x, center_y, center_z))\n",
    "    print('Size: ({:.6f}, {:.6f}, {:.6f})'.format(size_x, size_y, size_z))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "protein_path = 'eval_sets/posebusters/proteins'\n",
    "ligand_path = 'eval_sets/posebusters/ligands'\n",
    "meta_info_file = 'eval_sets/posebusters/posebuster_set_meta.csv'\n",
    "add_size=10\n",
    "grid_path = f'posebuster428_grid{add_size}'\n",
    "df = pd.read_csv(meta_info_file)\n",
    "pdb_ids = list(df['pdb_code'].values)\n",
    "lig_ids = list(df['lig_code'].values)\n",
    "# generate docking grid json file\n",
    "for pdbid, ligid in tqdm(zip(pdb_ids, lig_ids)):\n",
    "    calculated_docking_grid_sdf(ligand_path, grid_path, pdbid, ligid, add_size=add_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate the input metainfo csv file required for the interface"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(columns=['input_protein', 'input_ligand', 'input_docking_grid', 'output_ligand_name'])\n",
    "input_meta_info_file = 'posebuster_428_one2one.csv'\n",
    "predict_name_suffix = 'predict'\n",
    "for i, item in tqdm(enumerate(zip(pdb_ids,lig_ids))):\n",
    "    pdbid, ligid = item\n",
    "    single_protein_path = os.path.abspath(os.path.join(protein_path, pdbid + '.pdb'))\n",
    "    single_ligand_path = os.path.abspath(os.path.join(ligand_path, f'{pdbid}_{ligid}.sdf'))\n",
    "    single_grid_path = os.path.abspath(os.path.join(grid_path, pdbid + '.json'))\n",
    "    predict_name = f'{pdbid}_{predict_name_suffix}'\n",
    "    df.loc[i] = [single_protein_path, single_ligand_path, single_grid_path, predict_name]\n",
    "print(df.info())\n",
    "print(df.head(3))\n",
    "df.to_csv(input_meta_info_file, index= False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inference: use the model trained on moad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predict_sdf_dir = f'predict_sdf_posebuster428_grid{add_size}'\n",
    "\n",
    "!python demo.py --mode batch_one2one --batch-size 8 --steric-clash-fix --conf-size 10 --cluster \\\n",
    "        --input-batch-file $input_meta_info_file \\\n",
    "        --output-ligand-dir $predict_sdf_dir \\\n",
    "        --model-dir checkpoint_best.pt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cal metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rmsd_func(holo_coords, predict_coords):\n",
    "    if predict_coords is not np.nan:\n",
    "        sz = holo_coords.shape\n",
    "        rmsd = np.sqrt(np.sum((predict_coords - holo_coords)**2) / sz[0])\n",
    "        return rmsd\n",
    "    return 1000.0\n",
    "\n",
    "def rmsd_func_sym(holo_coords: np.ndarray, predict_coords: np.ndarray, mol: Optional[Chem.Mol] = None) -> float:\n",
    "    \"\"\" Symmetric RMSD for molecules. \"\"\"\n",
    "    if predict_coords is not np.nan:\n",
    "        sz = holo_coords.shape\n",
    "        if mol is not None:\n",
    "            # get stereochem-unaware permutations: (P, N)\n",
    "            base_perms = np.array(mol.GetSubstructMatches(mol, uniquify=False))\n",
    "            # filter for valid stereochem only\n",
    "            chem_order = np.array(list(Chem.rdmolfiles.CanonicalRankAtoms(mol, breakTies=False)))\n",
    "            perms_mask = (chem_order[base_perms] == chem_order[None]).sum(-1) == mol.GetNumAtoms()\n",
    "            base_perms = base_perms[perms_mask]\n",
    "            noh_mask = np.array([a.GetAtomicNum() != 1 for a in mol.GetAtoms()])\n",
    "            # (N, 3), (N, 3) -> (P, N, 3), ((), N, 3) -> (P,) -> min((P,))\n",
    "            best_rmsd = np.inf\n",
    "            for perm in base_perms:\n",
    "                rmsd = np.sqrt(np.sum((predict_coords[perm[noh_mask]] - holo_coords) ** 2) / sz[0])\n",
    "                if rmsd < best_rmsd:\n",
    "                    best_rmsd = rmsd\n",
    "\n",
    "            rmsd = best_rmsd\n",
    "        else:\n",
    "            rmsd = np.sqrt(np.sum((predict_coords - holo_coords) ** 2) / sz[0])\n",
    "        return rmsd\n",
    "    return 1000.0\n",
    "\n",
    "def cal_rmsd_metrics(predict_dir, reference_dir, meta_info_file):\n",
    "    df = pd.read_csv(meta_info_file)\n",
    "    pdb_ids = list(df['pdb_code'].values)\n",
    "    lig_ids = list(df['lig_code'].values)\n",
    "    failed_num = 0\n",
    "    rmsd_results, rmsd_sym_results = [], []\n",
    "    for pdb_id, lig_id in tqdm(zip(pdb_ids, lig_ids)):\n",
    "        target_ligand  = os.path.join(reference_dir, f'{pdb_id}_{lig_id}.sdf')\n",
    "        predict_ligand = os.path.join(predict_dir, f'{pdb_id}_{predict_name_suffix}.sdf')\n",
    "        target_supp = Chem.SDMolSupplier(target_ligand)\n",
    "        target_mol = [mol for mol in target_supp if mol][0]\n",
    "        target_mol = Chem.RemoveHs(target_mol)\n",
    "        holo_coords = target_mol.GetConformer().GetPositions().astype(np.float32)\n",
    "        target_atoms = [atom.GetSymbol() for atom in target_mol.GetAtoms()]\n",
    "        predict_mol = Chem.MolFromMolFile(predict_ligand, sanitize=False)\n",
    "        try:\n",
    "            predict_coords = predict_mol.GetConformer().GetPositions().astype(np.float32)\n",
    "        except:\n",
    "            print(f'failed pdb_id: {pdb_id}')\n",
    "            failed_num+=1\n",
    "            continue\n",
    "        predict_atoms = [atom.GetSymbol() for atom in predict_mol.GetAtoms()]\n",
    "        if predict_atoms == target_atoms:\n",
    "            rmsd = rmsd_func(holo_coords, predict_coords)\n",
    "            rmsd_new = rmsd_func_sym(holo_coords, predict_coords, predict_mol)\n",
    "            rmsd_results.append(rmsd)\n",
    "            rmsd_sym_results.append(rmsd_new)\n",
    "        else: \n",
    "            print(f'failed pdb_id: {pdb_id}')\n",
    "            failed_num+=1\n",
    "    rmsd_results = np.array(rmsd_results)\n",
    "    rmsd_sym_results = np.array(rmsd_sym_results)\n",
    "    return rmsd_results, rmsd_sym_results\n",
    "\n",
    "def print_result(rmsd_results):\n",
    "    print('*'*100)\n",
    "    print(f'results length: {len(rmsd_results)}')\n",
    "    print('RMSD < 0.5 : ', np.mean(rmsd_results<0.5))\n",
    "    print('RMSD < 1.0 : ', np.mean(rmsd_results<1.0))\n",
    "    print('RMSD < 1.5 : ', np.mean(rmsd_results<1.5))\n",
    "    print('RMSD < 2.0 : ', np.mean(rmsd_results<2.0))\n",
    "    print('RMSD < 3.0 : ', np.mean(rmsd_results<3.0))\n",
    "    print('RMSD < 5.0 : ', np.mean(rmsd_results<5.0))\n",
    "    print('avg RMSD : ', np.mean(rmsd_results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rmsd_results, rmsd_sym_results = cal_rmsd_metrics(predict_dir = predict_sdf_dir, \n",
    "                                                  reference_dir=ligand_path, \n",
    "                                                  meta_info_file=meta_info_file)\n",
    "print_result(rmsd_results)\n",
    "print_result(rmsd_sym_results)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
